'''
    Articulated-object motion dataset (MotionDataset) plus point-cloud and mesh helper functions.
    Items are loaded from pre-generated per-sample point, config, per-part mesh and rendered files.
'''

import os
import os.path
import json
import numpy as np
import math
import sys
import torch
import vgtk.so3conv.functional as L
from scipy.spatial.transform import Rotation as sciR
from SPConvNets.datasets.part_transform import revoluteTransform
from SPConvNets.models.model_util import *

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = BASE_DIR
sys.path.append(os.path.join(ROOT_DIR, 'utils'))
# import provider
from torch.utils import data
# from SPConvNets.models.common_utils import *
from SPConvNets.datasets.data_utils import *
import scipy.io as sio
import copy
# from model.utils import farthest_point_sampling

import os
os.environ["PYOPENGL_PLATFORM"] = "osmesa"
from scipy.spatial.transform import Rotation as sciR
from SPConvNets.datasets.part_transform import revoluteTransform
from SPConvNets.models.model_util import *
import random

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = BASE_DIR
sys.path.append(os.path.join(ROOT_DIR, 'utils'))
# import provider
from torch.utils import data
from SPConvNets.models.common_utils import *
from SPConvNets.datasets.data_utils import *
import scipy.io as sio
import copy
import trimesh
import pyrender

# padding 1
def padding_1(pos):
    """Append a column of ones to ``pos`` (homogeneous coordinates).

    Args:
        pos: 2-D array of shape (N, D).

    Returns:
        Array of shape (N, D + 1) whose last column is all ones.
    """
    pos = np.asarray(pos)
    # One pad entry per row, so any number of points is supported (the
    # previous fixed (1, 1) pad only worked for a single point).
    pad = np.ones((pos.shape[0], 1), dtype=np.float64)
    return np.concatenate([pos, pad], axis=1)

# normalize point-cloud
def pc_normalize(pc):
    """Center a point cloud at its centroid and scale it into the unit ball.

    Args:
        pc: array of shape (N, D), one point per row.

    Returns:
        The centered cloud divided by the largest point-to-centroid distance.
        If all points coincide (radius 0), the centered cloud (all zeros) is
        returned unscaled instead of dividing by zero.
    """
    centroid = np.mean(pc, axis=0)
    pc = pc - centroid
    radius = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
    if radius > 0:
        pc = pc / radius
    return pc


def reindex_triangles_vertices(vertices, triangles):
    """Compact a mesh so that it only keeps the vertices referenced by ``triangles``.

    Vertices get new ids in order of first appearance when ``triangles`` is
    scanned row by row (v1, v2, v3 of triangle 0, then triangle 1, ...).

    Args:
        vertices: (n_vert, 3) array of vertex positions.
        triangles: (n_tri, 3) array of vertex indices into ``vertices``
            (values are truncated to int).

    Returns:
        new_vertices: (n_used, 3) float64 array holding the referenced vertices.
        new_triangles: (n_tri, 3) array, same dtype as ``triangles``, indexing
            into ``new_vertices``.
    """
    triangles = np.asarray(triangles)
    flat_idx = triangles.reshape(-1).astype(np.int64)
    uniq, first_pos, inverse = np.unique(flat_idx, return_index=True, return_inverse=True)
    # np.unique orders by value; re-rank so new ids follow first appearance.
    order = np.argsort(first_pos, kind='stable')
    new_id_of_uniq = np.empty_like(order)
    new_id_of_uniq[order] = np.arange(order.shape[0])

    new_vertices = np.asarray(vertices, dtype=np.float64)[uniq[order]]
    new_triangles = np.zeros_like(triangles)
    new_triangles[...] = new_id_of_uniq[inverse.reshape(-1)].reshape(triangles.shape)
    # new_triangles: tot_n_tri x 3
    # new_vertices: tot_n_vert x 3
    return new_vertices, new_triangles

def ndc_depth_to_buffer(z, near, far):
    """Map NDC depth ``z`` in [-1, 1] to depth in [near, far].

    z = -1 maps to ``near`` and z = 1 maps to ``far``.
    """
    numerator = 2 * near * far
    denominator = near + far - z * (far - near)
    return numerator / denominator


def buffer_depth_to_ndc(d, near, far):
    """Map depth ``d`` in (0, +inf) back to NDC depth; inverse of ndc_depth_to_buffer.

    ``d`` is clipped to [1e-6, 1e6] first so that zero or huge depths stay finite.
    """
    safe_depth = np.clip(d, a_min=1e-6, a_max=1e6)
    inv_term = 2 * near * far / safe_depth
    return ((near + far) - inv_term) / (far - near)


# Decode a rotation-info integer encoding into a list of direction vectors
def decode_rotation_info(rotate_info_encoding):
    """Decode an integer rotation-info encoding into a list of direction vectors.

    Encodings:
        0, or anything > 6: ``[]``.
        1: three random unit vectors in the x-z plane (angles drawn from [0, pi)).
        2: three random unit vectors in the x-y plane.
        3, and any other value below 3 not listed above: three random unit
           vectors in the y-z plane.
        4: [x-axis, z-axis]; 5: [x-axis, y-axis]; 6: [y-axis, z-axis].

    Returns:
        List of float64 arrays of shape (3,).
    """
    if rotate_info_encoding == 0:
        return []
    if rotate_info_encoding <= 3:
        # One random angle per returned vector.
        temp_angle = np.reshape(np.random.rand(3) * np.pi, (3, 1))
        cos_a = np.cos(temp_angle)
        sin_a = np.sin(temp_angle)
        zero = np.zeros_like(temp_angle)
        if rotate_info_encoding == 1:
            columns = [cos_a, zero, sin_a]
        elif rotate_info_encoding == 2:
            columns = [cos_a, sin_a, zero]
        else:
            columns = [zero, cos_a, sin_a]
        line_vec = np.concatenate(columns, axis=-1)  # (3, 3): one vector per row
        return [line_vec[0], line_vec[1], line_vec[2]]
    if rotate_info_encoding <= 6:
        # np.float64 instead of the np.float alias removed in NumPy 1.24.
        x_axis = np.array([1.0, 0.0, 0.0], dtype=np.float64)
        y_axis = np.array([0.0, 1.0, 0.0], dtype=np.float64)
        z_axis = np.array([0.0, 0.0, 1.0], dtype=np.float64)
        if rotate_info_encoding == 4:
            return [x_axis, z_axis]
        if rotate_info_encoding == 5:
            return [x_axis, y_axis]
        return [y_axis, z_axis]
    return []


def rotate_by_vec_pts(un_w, p_x, bf_rotate_pos):
    """Rotate points about an axis line by a random angle (Rodrigues' formula).

    The axis line has direction ``un_w`` and passes through ``p_x``. The angle
    is ``effi * pi``, where ``effi`` is drawn uniformly from [-0.25, 0.25) and
    then pushed 0.1 further from zero, so |angle| lies in [0.1*pi, 0.35*pi].
    Consumes one ``np.random.uniform`` draw.

    Args:
        un_w: 3-vector axis direction; need not be unit length.
        p_x: 3-vector, a point on the axis.
        bf_rotate_pos: (N, 3) points to rotate.

    Returns:
        af_rotate_pos: (N, 3) rotated points.
        rotation_matrix: (3, 3) rotation about the axis direction.
        trans: (3, 1) point of the axis line closest to the origin.
    """
    # Unit axis and its cross-product (skew-symmetric) matrix K.
    w = un_w / np.sqrt(np.sum(un_w ** 2, axis=0))
    wx, wy, wz = float(w[0]), float(w[1]), float(w[2])
    skew = np.array([[0, -wz, wy], [wz, 0, -wx], [-wy, wx, 0]])

    half_range, gap = 0.25, 0.1
    effi = np.random.uniform(-half_range, half_range, (1,)).item()
    effi = effi - gap if effi < 0 else effi + gap
    theta = effi * np.pi

    # R = I + sin(theta) K + (1 - cos(theta)) K^2
    rotation_matrix = np.eye(3) + skew * np.sin(theta) + (skew.dot(skew)) * (1. - np.cos(theta))

    # Remove from p_x its component along the axis: the result is the point of
    # the line {p_x + t * un_w} nearest the origin.
    along = np.sum(p_x * un_w).item()
    axis_sq = np.sum(un_w ** 2).item()
    step = -along / (axis_sq + 1e-10)
    trans = np.reshape(p_x + un_w * step, (1, 3))

    # Rotate about the line: move it to the origin, rotate, move back.
    shifted = bf_rotate_pos - trans
    af_rotate_pos = np.transpose(np.matmul(rotation_matrix, np.transpose(shifted, [1, 0])), [1, 0]) + trans
    return af_rotate_pos, rotation_matrix, np.reshape(trans, (3, 1))


class MotionDataset(data.Dataset):
    """Dataset of articulated shapes backed by pre-generated per-sample files.

    Every sub-folder of ``<root>/<split>`` is one shape. For each shape,
    ``self.n_samples`` samples are listed. Sample ``i`` consists of
    ``sample_points{i}.npy``, ``sample_cfg{i}.npy``, ``rendered{i}.npy`` and
    one mesh per part, ``sample_mesh{i}_{part}.obj``. The flat dataset index
    is ``shape_position * n_samples + i``.
    """

    def __init__(
            self, root="./data/MDV02", npoints=512, split='train', nmask=10, shape_type="laptop", args=None, global_rot=0
    ):
        super(MotionDataset, self).__init__()

        self.root = root
        self.npoints = npoints
        self.shape_type = shape_type
        self.shape_root = os.path.join(self.root, shape_type)
        self.args = args

        self.mesh_fn = "summary.obj"
        self.surface_to_seg_fn = "sfs_idx_to_dof_name_idx.npy"
        self.attribute_fn = "motion_attributes.json"

        self.global_rot = global_rot
        self.split = split
        # ``args`` is effectively required: these settings are read from it.
        self.rot_factor = self.args.equi_settings.rot_factor
        self.no_articulation = self.args.equi_settings.no_articulation
        self.pre_compute_delta = self.args.equi_settings.pre_compute_delta
        self.use_multi_sample = self.args.equi_settings.use_multi_sample
        self.n_samples = self.args.equi_settings.n_samples if self.use_multi_sample == 1 else 1

        self.dataset_root = os.path.join(root, self.split)

        if self.pre_compute_delta == 1 and self.split == 'train':
            self.use_multi_sample = 0
            self.n_samples = 1

        # NOTE(review): each shape folder is expected to hold 10x this many
        # samples; the reason for the factor is not visible in this file.
        self.n_samples = self.n_samples * 10
        print(f"no_articulation: {self.no_articulation}, equi_settings: {self.args.equi_settings.no_articulation}")

        self.train_ratio = 0.9

        self.shape_folders = []

        self.pts_folders = []
        self.cfg_folders = []
        self.shape_indexes = []
        self.rendered_folders = []

        # NOTE(review): os.listdir order is filesystem dependent, so the
        # index -> shape mapping is not guaranteed to be stable across machines.
        shape_idxes = os.listdir(self.dataset_root)
        for shp_idx in shape_idxes:
            cur_shp_folder = os.path.join(self.dataset_root, shp_idx)
            self.shape_indexes.append(shp_idx)
            for i in range(self.n_samples):
                self.pts_folders.append(os.path.join(cur_shp_folder, f"sample_points{i}.npy"))
                self.cfg_folders.append(os.path.join(cur_shp_folder, f"sample_cfg{i}.npy"))
                self.rendered_folders.append(os.path.join(cur_shp_folder, f"rendered{i}.npy"))

        # Shape folder names: one entry per shape, not per sample.
        self.shape_idxes = shape_idxes

        self.anchors = L.get_anchors(args.model.kanchor)
        self.kanchor = args.model.kanchor

        # Counter exposed via get/reset_num_moving_parts_to_cnt; initialised
        # here so the getter works before the first reset.
        self.num_mov_parts_to_cnt = {}

    def get_trans_encoding_to_trans_dir(self):
        """Build ``trans_mode_to_trans_dir`` and the three unit axis vectors.

        Must be called before ``decode_trans_dir``.
        """
        trans_dir_to_trans_mode = {
            (0, 1, 2): 1,
            (0, 2): 2,
            (0, 1): 3,
            (1, 2): 4,
            (0,): 5,
            (1,): 6,
            (2,): 7,
            (): 0
        }
        self.trans_mode_to_trans_dir = {mode: dirs for dirs, mode in trans_dir_to_trans_mode.items()}
        self.base_transition_vec = [
            np.array([1.0, 0.0, 0.0], dtype=np.float64),
            np.array([0.0, 1.0, 0.0], dtype=np.float64),
            np.array([0.0, 0.0, 1.0], dtype=np.float64),
        ]

    def get_test_data(self):
        # NOTE(review): ``self.test_data`` is never assigned in this class.
        return self.test_data

    def reindex_data_index(self):
        """Build ``new_idx_to_old_idx`` mapping positions 0..n-1 to keys of ``self.data``.

        NOTE(review): ``self.data`` is never assigned in this class.
        """
        old_idx_to_new_idx = {old_idx: ii for ii, old_idx in enumerate(self.data)}
        self.new_idx_to_old_idx = {new: old for old, new in old_idx_to_new_idx.items()}

    def reindex_shape_seg(self, shape_seg):
        """Relabel segments in place as 0, 1, 2, ... in order of first appearance."""
        old_seg_to_new_seg = {}
        for i in range(shape_seg.shape[0]):
            old_seg = int(shape_seg[i].item())
            if old_seg not in old_seg_to_new_seg:
                old_seg_to_new_seg[old_seg] = len(old_seg_to_new_seg)
            shape_seg[i] = old_seg_to_new_seg[old_seg]
        return shape_seg

    def transit_pos_by_transit_vec(self, trans_pos):
        """Translate ``trans_pos`` along a random unit direction by 0.2-0.4 units.

        Returns the translated positions and the applied offset vector.
        """
        tdir = np.random.uniform(-1.0, 1.0, (3,))
        tdir = tdir / (np.sqrt(np.sum(tdir ** 2)).item() + 1e-9)
        trans_scale = np.random.uniform(1.0, 2.0, (1,)).item()
        offset = tdir * 0.1 * trans_scale * 2
        return trans_pos + offset, offset

    def transit_pos_by_transit_vec_dir(self, trans_pos, tdir):
        """Translate ``trans_pos`` along ``tdir`` scaled by a random factor in [0, 0.1).

        Returns the translated positions and the applied offset vector.
        """
        trans_scale = np.random.uniform(0.0, 1.0, (1,)).item()
        offset = tdir * 0.1 * trans_scale
        return trans_pos + offset, offset

    def get_random_transition_dir_scale(self):
        """Return a random unit direction and a random scale in [1, 2)."""
        tdir = np.random.uniform(-1.0, 1.0, (3,))
        tdir = tdir / (np.sqrt(np.sum(tdir ** 2)).item() + 1e-9)
        trans_scale = np.random.uniform(1.0, 2.0, (1,)).item()
        return tdir, trans_scale

    def decode_trans_dir(self, trans_encoding):
        """Return the axis vectors for a translation-mode encoding.

        Requires ``get_trans_encoding_to_trans_dir`` to have been called.
        """
        trans_dir = self.trans_mode_to_trans_dir[trans_encoding]
        return [self.base_transition_vec[d] for d in trans_dir]

    def get_rotation_from_anchor(self):
        """Return one of the ``kanchor`` anchor rotations, chosen uniformly at random."""
        ii = int(np.random.randint(0, self.kanchor, (1,)).item())
        return self.anchors[ii]

    def get_whole_shape_by_idx(self, index):
        """Sample ``npoints`` points from the summary mesh of a shape; returns a float tensor.

        NOTE(review): reads shape ``index + 1`` under ``self.shape_root``
        (root/shape_type), not ``self.dataset_root``. Confirm both are intended.
        """
        shp_idx = self.shape_idxes[index + 1]
        cur_folder = os.path.join(self.shape_root, shp_idx)

        cur_mesh_fn = os.path.join(cur_folder, self.mesh_fn)
        cur_surface_to_seg_fn = os.path.join(cur_folder, self.surface_to_seg_fn)
        cur_motion_attributes_fn = os.path.join(cur_folder, self.attribute_fn)

        cur_vertices, cur_triangles = load_vertices_triangles(cur_mesh_fn)
        cur_triangles_to_seg_idx, seg_idx_to_triangle_idxes = load_triangles_to_seg_idx(cur_surface_to_seg_fn)
        cur_motion_attributes = load_motion_attributes(cur_motion_attributes_fn)

        sampled_pcts, pts_to_seg_idx, seg_idx_to_sampled_pts = sample_pts_from_mesh(cur_vertices, cur_triangles,
                                                                                    cur_triangles_to_seg_idx,
                                                                                    npoints=self.npoints)
        sampled_pcts = torch.from_numpy(sampled_pcts).float()
        return sampled_pcts

    def get_shape_by_idx(self, index):
        """Return per-part sampled points as a (n_parts, max_part_pts, 3) float tensor.

        Parts with fewer points are padded with copies of their own centroid.
        NOTE(review): same ``index + 1`` / ``shape_root`` caveat as
        ``get_whole_shape_by_idx``.
        """
        shp_idx = self.shape_idxes[index + 1]
        cur_folder = os.path.join(self.shape_root, shp_idx)

        cur_mesh_fn = os.path.join(cur_folder, self.mesh_fn)
        cur_surface_to_seg_fn = os.path.join(cur_folder, self.surface_to_seg_fn)
        cur_motion_attributes_fn = os.path.join(cur_folder, self.attribute_fn)

        cur_vertices, cur_triangles = load_vertices_triangles(cur_mesh_fn)
        cur_triangles_to_seg_idx, seg_idx_to_triangle_idxes = load_triangles_to_seg_idx(cur_surface_to_seg_fn)
        cur_motion_attributes = load_motion_attributes(cur_motion_attributes_fn)

        sampled_pcts, pts_to_seg_idx, seg_idx_to_sampled_pts = sample_pts_from_mesh(cur_vertices, cur_triangles,
                                                                                    cur_triangles_to_seg_idx,
                                                                                    npoints=self.npoints)

        # Points of each segmentation/part.
        seg_pts = []
        for i_seg in range(len(cur_motion_attributes)):
            cur_seg_pts_idxes = np.array(seg_idx_to_sampled_pts[i_seg], dtype=np.int64)
            seg_pts.append(sampled_pcts[cur_seg_pts_idxes])
        maxx_nn_pt = max(pts.shape[0] for pts in seg_pts)

        res_pts = []
        for trans_pts in seg_pts:
            n_missing = maxx_nn_pt - trans_pts.shape[0]
            if n_missing > 0:
                cur_seg_center_pt = np.mean(trans_pts, axis=0, keepdims=True)
                trans_pts = np.concatenate([trans_pts, np.repeat(cur_seg_center_pt, n_missing, axis=0)], axis=0)
            res_pts.append(np.reshape(trans_pts, (1, maxx_nn_pt, 3)))

        res_pts = np.concatenate(res_pts, axis=0)
        return torch.from_numpy(res_pts).float()

    def refine_triangle_idxes_by_seg_idx(self, seg_idx_to_triangle_idxes, cur_triangles):
        """Regroup triangles by segment.

        Returns the concatenated triangles and an int64 segment id per triangle.
        """
        res_triangles = []
        cur_triangles_to_seg_idx = []
        for seg_idx in seg_idx_to_triangle_idxes:
            cur_triangle_idxes = np.array(seg_idx_to_triangle_idxes[seg_idx], dtype=np.int64)
            res_triangles.append(cur_triangles[cur_triangle_idxes])
            cur_triangles_to_seg_idx += [seg_idx] * cur_triangle_idxes.shape[0]
        res_triangles = np.concatenate(res_triangles, axis=0)
        cur_triangles_to_seg_idx = np.array(cur_triangles_to_seg_idx, dtype=np.int64)
        return res_triangles, cur_triangles_to_seg_idx

    def _sample_location(self, index):
        """Map a flat dataset index to ``(shape_position, sample_idx)``.

        ``__init__`` lists ``self.n_samples`` files per shape folder, so
        ``index == shape_position * self.n_samples + sample_idx``.
        """
        return divmod(int(index), self.n_samples)

    def __getitem__(self, index):
        """Load sample ``index`` and return a dict of tensors.

        Points are centred on the global translation and subsampled with
        farthest point sampling: ``npoints`` for 'pc'/'af_pc'/'label'/'pose'
        and 4096 for the 'oorr_*' entries.
        """
        shape_index = index
        cur_pts_fn = self.pts_folders[shape_index]
        cur_cfg_fn = self.cfg_folders[shape_index]

        # Must agree with how the file lists were built in __init__ (previously
        # hard-coded to 100 samples per shape).
        shape_idx, sample_idx = self._sample_location(index)

        # After the transpose, columns 0-2 are xyz (the following column holds labels).
        cur_pts = np.load(cur_pts_fn, allow_pickle=True)
        cur_pts = np.transpose(cur_pts, (1, 0))[:, :3]

        # Load every part mesh and stack them into one vertex / triangle array.
        cur_cfg = np.load(cur_cfg_fn, allow_pickle=True).item()
        shape_folder = os.path.join(self.dataset_root, self.shape_idxes[shape_idx])
        cur_seg_label_to_triangles = {}
        tot_vertices = []
        tot_triangles = []
        tot_n_triangles = 0
        tot_n_vertices = 0
        for cur_n_part, seg_name in enumerate(cur_cfg):
            # The segment label is the last character of the part name.
            cur_seg_label = int(seg_name[-1])
            part_mesh_fn = os.path.join(shape_folder, f"sample_mesh{sample_idx}_{cur_n_part}.obj")
            cur_part_vertices, cur_part_triangles = load_vertices_triangles(part_mesh_fn)
            n_part_vertices = cur_part_vertices.shape[0]
            n_part_triangles = cur_part_triangles.shape[0]
            cur_seg_label_to_triangles[cur_seg_label] = np.arange(
                tot_n_triangles, tot_n_triangles + n_part_triangles, dtype=np.int64
            )
            tot_vertices.append(cur_part_vertices)
            # Offset vertex ids so they index the concatenated vertex array.
            tot_triangles.append(cur_part_triangles + tot_n_vertices)
            tot_n_vertices += n_part_vertices
            tot_n_triangles += n_part_triangles

        tot_vertices = np.concatenate(tot_vertices, axis=0)
        tot_triangles = np.concatenate(tot_triangles, axis=0)

        ori_pts = np.zeros_like(cur_pts)

        cur_rendered_data = np.load(self.rendered_folders[shape_index], allow_pickle=True).item()
        seg_label_to_pts = cur_rendered_data['seg_label_to_pts']
        seg_label_to_new_pose = cur_rendered_data['seg_label_to_new_pose']
        glb_pose = cur_rendered_data['glb_pose']
        part_axis = cur_rendered_data['part_axis']
        canon_seg_label_to_pts = cur_rendered_data['canon_seg_label_to_pts']

        ''' Get per-point pose and per-part pose '''
        cur_n_seg = len(seg_label_to_pts)
        # Per-part arrays are indexed by segment label, so labels are assumed
        # to lie in [0, n_parts). At least 4 state-rotation slots are kept so
        # the output shape is unchanged for shapes with <= 4 parts.
        part_state_orts = np.zeros((max(4, cur_n_seg), 3, 3), dtype=np.float64)
        tot_transformation_mtx_segs = np.zeros((cur_n_seg, 4, 4), dtype=np.float64)
        part_ref_rots = np.zeros((cur_n_seg, 3, 3), dtype=np.float64)
        part_ref_trans = np.zeros((cur_n_seg, 3), dtype=np.float64)
        part_ref_trans_bbox = np.zeros((cur_n_seg, 3), dtype=np.float64)
        part_state_trans_bbox = np.zeros_like(part_ref_trans_bbox)

        glb_rotation, glb_trans = glb_pose['rotation'], glb_pose['trans']
        # Express the part axes in the globally rotated frame.
        part_axis = np.concatenate(part_axis, axis=0)
        part_axis = np.transpose(np.matmul(glb_rotation, np.transpose(part_axis, (1, 0))), (1, 0))

        rotated_ori_pts = np.transpose(np.matmul(glb_rotation, np.transpose(ori_pts, (1, 0))), (1, 0)) + np.reshape(glb_trans, (1, 3))

        tot_transformed_pts = []
        canon_transformed_pts = []
        pts_to_seg_idx = []
        tot_transformation_mtx = []
        for seg_label in seg_label_to_pts:
            cur_seg_trans_pts = seg_label_to_pts[seg_label]
            cur_seg_pose = seg_label_to_new_pose[seg_label]
            n_seg_pts = cur_seg_trans_pts.shape[0]
            tot_transformed_pts.append(cur_seg_trans_pts)  # partial transformed points
            pts_to_seg_idx += [seg_label] * n_seg_pts
            canon_transformed_pts.append(canon_seg_label_to_pts[seg_label])

            # Corner vertices of every triangle of this part.
            cur_seg_tri = tot_triangles[cur_seg_label_to_triangles[seg_label]]
            cur_seg_tri_vertices = np.concatenate(
                [tot_vertices[cur_seg_tri[:, 0]], tot_vertices[cur_seg_tri[:, 1]], tot_vertices[cur_seg_tri[:, 2]]],
                axis=0,
            )

            # Part mesh vertices under the current part pose.
            cur_seg_rot, cur_seg_trans = cur_seg_pose[:3, :3], cur_seg_pose[:3, 3]
            rot_cur_seg_tri_vertices = np.matmul(cur_seg_rot, np.transpose(cur_seg_tri_vertices, (1, 0))) + np.reshape(
                cur_seg_trans, (3, 1))
            rot_cur_seg_tri_vertices = np.transpose(rot_cur_seg_tri_vertices, (1, 0))

            # Offset of the part translation from its posed bounding-box centre.
            cur_seg_pts_bbox_center = (np.min(rot_cur_seg_tri_vertices, axis=0) + np.max(rot_cur_seg_tri_vertices, axis=0)) / 2.
            part_state_trans_bbox[seg_label] = cur_seg_trans - cur_seg_pts_bbox_center
            part_state_orts[seg_label] = cur_seg_rot
            tot_transformation_mtx_segs[seg_label] = cur_seg_pose
            tot_transformation_mtx += [np.reshape(cur_seg_pose, (1, 4, 4))] * n_seg_pts

        tot_transformed_pts = np.concatenate(tot_transformed_pts, axis=0)
        canon_transformed_pts = np.concatenate(canon_transformed_pts, axis=0)
        gt_pose = np.concatenate(tot_transformation_mtx, axis=0)  # one 4x4 pose per point
        pts_to_seg_idx = np.array(pts_to_seg_idx, dtype=np.int64)

        ''' Point normalization via centralization '''
        # ori_pts is all zeros, so this centre equals the global translation.
        af_glb_center_pt = np.mean(rotated_ori_pts, axis=0)
        tot_transformed_pts = tot_transformed_pts - af_glb_center_pt.reshape(1, 3)
        gt_pose[:, :3, 3] = gt_pose[:, :3, 3] - af_glb_center_pt
        tot_transformation_mtx_segs[:, :3, 3] = tot_transformation_mtx_segs[:, :3, 3] - af_glb_center_pt

        all_pc = torch.from_numpy(tot_transformed_pts.astype(np.float32)).float()
        all_label = torch.from_numpy(pts_to_seg_idx).long()
        cur_pose = torch.from_numpy(gt_pose.astype(np.float32))
        cur_pose_segs = torch.from_numpy(tot_transformation_mtx_segs.astype(np.float32))
        cur_canon_transformed_pts = torch.from_numpy(canon_transformed_pts.astype(np.float32)).float()
        cur_part_state_rots = torch.from_numpy(part_state_orts.astype(np.float32)).float()
        cur_part_ref_rots = torch.from_numpy(part_ref_rots.astype(np.float32)).float()
        cur_part_ref_trans = torch.from_numpy(part_ref_trans.astype(np.float32)).float()
        part_ref_trans_bbox = torch.from_numpy(part_ref_trans_bbox.astype(np.float32)).float()
        part_state_trans_bbox = torch.from_numpy(part_state_trans_bbox.astype(np.float32)).float()
        cur_part_axis = torch.from_numpy(part_axis.astype(np.float32)).float()

        fps_idx = farthest_point_sampling(all_pc.unsqueeze(0), n_sampling=self.npoints)
        fps_idx_oorr = farthest_point_sampling(all_pc.unsqueeze(0), n_sampling=4096)
        oorr_pc = all_pc[fps_idx_oorr]
        oorr_label = all_label[fps_idx_oorr]
        cur_pc = all_pc[fps_idx]
        cur_label = all_label[fps_idx]
        cur_pose = cur_pose[fps_idx]

        canon_fps_idx = farthest_point_sampling(cur_canon_transformed_pts.unsqueeze(0), n_sampling=self.npoints)
        canon_fps_idx = canon_fps_idx[:self.npoints]
        canon_fps_idx_oorr = farthest_point_sampling(cur_canon_transformed_pts.unsqueeze(0), n_sampling=4096)
        cur_oorr_canon_transformed_pts = cur_canon_transformed_pts[canon_fps_idx_oorr]
        cur_canon_transformed_pts = cur_canon_transformed_pts[canon_fps_idx]

        idx_arr = torch.from_numpy(np.array([index], dtype=np.int64)).long()

        rt_dict = {
            'pc': cur_pc.contiguous().transpose(0, 1).contiguous(),
            'af_pc': cur_pc.contiguous().transpose(0, 1).contiguous(),
            'ori_pc': oorr_pc.contiguous().transpose(0, 1).contiguous(),
            'canon_pc': cur_canon_transformed_pts,
            'oorr_pc': oorr_pc.contiguous().transpose(0, 1).contiguous(),
            'oorr_canon_pc': cur_oorr_canon_transformed_pts.contiguous(),
            'label': cur_label,
            'oorr_label': oorr_label,
            'pose': cur_pose,
            'pose_segs': cur_pose_segs,
            'part_state_rots': cur_part_state_rots,
            'part_ref_rots': cur_part_ref_rots,
            'part_ref_trans': cur_part_ref_trans,
            'idx': idx_arr,
            'part_state_trans_bbox': part_state_trans_bbox,
            'part_ref_trans_bbox': part_ref_trans_bbox,
            'part_axis': cur_part_axis
        }

        return rt_dict

    def __len__(self):
        return len(self.shape_idxes) * self.n_samples

    def get_num_moving_parts_to_cnt(self):
        return self.num_mov_parts_to_cnt

    def reset_num_moving_parts_to_cnt(self):
        self.num_mov_parts_to_cnt = {}

if __name__ == '__main__':
    # Smoke test: build a MotionDataset from the command line and time a few items.
    import argparse
    import time
    from types import SimpleNamespace

    parser = argparse.ArgumentParser(description="MotionDataset smoke test")
    parser.add_argument('--root', default='./data/MDV02')
    parser.add_argument('--split', default='test')
    parser.add_argument('--npoints', type=int, default=512)
    # Placeholder settings: set them to match the training configuration in use.
    parser.add_argument('--kanchor', type=int, default=60)
    parser.add_argument('--rot-factor', type=float, default=0.5)
    parser.add_argument('--n-items', type=int, default=10)
    cli = parser.parse_args()

    # Minimal stand-in for the config object whose fields MotionDataset reads.
    args = SimpleNamespace(
        equi_settings=SimpleNamespace(
            rot_factor=cli.rot_factor,
            no_articulation=0,
            pre_compute_delta=0,
            use_multi_sample=0,
            n_samples=1,
        ),
        model=SimpleNamespace(kanchor=cli.kanchor),
    )
    d = MotionDataset(root=cli.root, npoints=cli.npoints, split=cli.split, args=args)
    print(len(d))

    tic = time.time()
    sample = None
    for i in range(min(cli.n_items, len(d))):
        sample = d[i]
    print(time.time() - tic)
    if sample is not None:
        for key, value in sample.items():
            print(key, tuple(value.shape))
